from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple
from numpy.typing import NDArray
import pandas as pd
import networkx as nx
import numpy as np

class CustomCausalModel(ABC):
    """
    Abstract base class for custom causal models.

    This class provides an interface for defining, fitting and querying custom causal models.
    Subclasses must implement `__init__`, `identify_effect`, `fit` and `estimate_effect`.

    Attributes:
        None

    Methods:
        __init__(causal_graph): Initializes the custom causal model from a causal graph.
        fit(data, int_table, method_params, seed, save_dir, *args, **kwargs): Fits the causal model using the provided data.
        identify_effect(treatment, outcome, obs_data, int_data, method_params, seed, save_dir): Identifies the causal effect of a treatment on an outcome.
        estimate_effect(outcome, treatment, evidence, method_params, seed, save_dir, data, int_table): Estimates the causal effect of a treatment on an outcome using the given evidence.

    Concrete helpers available to subclasses:
        extract_treat_control(treatment): Splits a treatment specification into treatment and control dicts.
        monitor_gpu_usage(model): Returns the memory of a model's parameters and buffers, in bytes.
        condition_and_quantize(treated_samples, control_samples, evidence, data, quantize): Filters samples on evidence values.
        get_probability_distribution(samples, bin_sequence): Builds a discrete probability distribution from samples.
        parse_vars(df): Splits DataFrame columns into continuous, discrete and binary variables.

    """

    @abstractmethod
    def __init__(self, causal_graph: "nx.Graph"):
        """
        Initialize the model from a causal graph.

        Args:
            causal_graph (nx.Graph): The causal graph the model is built on.
        """
        pass

    @abstractmethod
    def identify_effect(
        self,
        treatment: Optional[dict[str, float]] = {},
        outcome: Optional[dict[str, float]] = {},
        obs_data: Optional[pd.DataFrame] = None,
        int_data: Optional[pd.DataFrame] = None,
        method_params: dict[str, Any] = {},
        seed: Optional[int] = None,
        save_dir: Optional[str] = None,
    ) -> dict[str, Any]:
        """
        Identifies the causal effect of a treatment on an outcome.

        Args:
            treatment (Optional[dict[str, float]]): Treatment variables and their values. Defaults to an empty dictionary.
            outcome (Optional[dict[str, float]]): Outcome variables and their values. Defaults to an empty dictionary.
            obs_data (pd.DataFrame, optional): Observational data. Defaults to None.
            int_data (pd.DataFrame, optional): Interventional data (exact format defined by subclasses). Defaults to None.
            method_params (dict[str, Any]): Additional parameters for the identification method. Defaults to {}.
            seed (int, optional): Random seed for reproducibility. Defaults to None.
            save_dir (str, optional): Directory to save results. Defaults to None.

        Returns:
            dict[str, Any]: Identification results; the keys are defined by each subclass.
        """
        pass

    @abstractmethod
    def fit(
        self,
        data: Optional[pd.DataFrame] = None,
        int_table: Optional[pd.DataFrame] = None,
        method_params: dict[str, Any] = {},
        seed: Optional[int] = None,
        save_dir: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Fits the causal model using the provided observational data.

        Args:
            data (pd.DataFrame): A pandas DataFrame containing the observational data. Defaults to None.
            int_table (pd.DataFrame): A pandas DataFrame containing the interventional table, with 1 where the intervention is active, and 0 otherwise. Defaults to None.
            method_params (dict[str, Any], optional): Additional parameters for the fitting method. Defaults to {}.
            seed (int, optional): Random seed for reproducibility. Defaults to None.
            save_dir (str, optional): Directory to save the fitted model. Defaults to None.
            *args: Extra positional arguments for subclass-specific options.
            **kwargs: Extra keyword arguments for subclass-specific options.

        Returns:
            dict[str, Any]: A dictionary containing the fitted model parameters and other relevant information.
        """
        pass

    @abstractmethod
    def estimate_effect(
        self,
        outcome: str,
        treatment: Optional[dict[str, float]] = {},
        evidence: dict[str, float] = {},
        method_params: dict[str, Any] = {},
        seed: Optional[int] = None,
        save_dir: Optional[str] = None,
        data: Optional[pd.DataFrame] = None,
        int_table: Optional[pd.DataFrame] = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Estimates the causal effect of a treatment on an outcome using the given evidence.

        Args:
            outcome (str): The name of the outcome variable.
            treatment (Optional[dict[str, float]]): A dictionary representing the treatment variables and their values. Defaults to an empty dictionary.
            evidence (dict[str, float]): A dictionary representing the evidence variables and their values. Defaults to an empty dictionary.
            method_params (dict[str, Any]): A dictionary of additional parameters for the estimation method. Defaults to an empty dictionary.
            seed (int, optional): Random seed for reproducibility. Defaults to None.
            save_dir (str, optional): Directory to save the estimation results. Defaults to None.
            data (pd.DataFrame): A pandas DataFrame containing the observational data. Defaults to None.
            int_table (pd.DataFrame): A pandas DataFrame containing the interventional table, with 1 where the intervention is active, and 0 otherwise. Defaults to None.

        Returns:
            tuple[dict[str, Any], dict[str, Any]]: A tuple containing two dictionaries.
            The first dictionary represents the estimated causal quantities, and the second dictionary contains additional information or statistics related to the estimation such as runtime and memory usage.

        """
        pass

    def extract_treat_control(self, treatment: Dict[str, Any]) -> Tuple[Dict[Any, Any], Dict[Any, Any]]:
        """
        Split a treatment specification into a treatment dict and a control dict.

        Args:
            treatment (Dict[str, Any]): Must contain the keys "treatment_var", "treatment_value",
                "control_var" and "control_value".

        Returns:
            tuple: ({treatment_var: treatment_value}, {control_var: control_value}).

        Raises:
            KeyError: If any of the four required keys is missing.
        """
        treatment_dict = {treatment["treatment_var"]: treatment["treatment_value"]}
        control_dict = {treatment["control_var"]: treatment["control_value"]}
        return treatment_dict, control_dict

    def monitor_gpu_usage(self, model):
        """
        Return the memory occupied by a model's parameters and buffers, in bytes.

        The value is the sum of ``nelement() * element_size()`` over ``model.parameters()``
        and ``model.buffers()``. No device is queried, so this is the tensor footprint
        of the model wherever it lives.

        Args:
            model (torch.nn.Module): The PyTorch model to measure.

        Returns:
            int: Total size of parameters plus buffers, in bytes.
        """
        mem_params = sum(param.nelement() * param.element_size() for param in model.parameters())
        mem_bufs = sum(buf.nelement() * buf.element_size() for buf in model.buffers())
        return mem_params + mem_bufs

    def condition_and_quantize(
        self,
        treated_samples,
        control_samples,
        evidence: Optional[dict] = None,
        data=None,
        quantize=False,
    ):
        """
        Keep only the treated and control rows whose evidence columns equal the evidence values.

        Args:
            treated_samples (pd.DataFrame): Samples generated under treatment.
            control_samples (pd.DataFrame): Samples generated under control.
            evidence (dict, optional): Mapping of column name -> value. A row is kept only if every
                evidence column equals its value. Defaults to no evidence (nothing is filtered).
            data (pd.DataFrame, optional): Reference data whose observed values are used as bins
                when ``quantize`` is True.
            quantize (bool): If True, every evidence column that has fewer than 50 distinct values
                in ``data`` is first snapped to the closest value observed in ``data`` (first one on
                ties), so that continuous samples can match the evidence exactly.

        Returns:
            tuple: The filtered ``(treated_samples, control_samples)``. The input DataFrames are
            not modified.

        Raises:
            ValueError: If ``quantize`` is True, ``evidence`` is non-empty and ``data`` is None.
        """
        evidence = {} if evidence is None else evidence

        def discretize_value(value, bins):
            # Snap a single value to the closest bin value; np.argmin keeps the first on ties.
            return bins[np.argmin(np.abs(bins - value))]

        if quantize and evidence:
            if data is None:
                raise ValueError(
                    "The data parameter is required to bin continuous values."
                )
            # Quantize copies so the caller's DataFrames are never overwritten in place.
            # Quantizing every column before filtering is equivalent to interleaving the two,
            # since the snapping is element-wise.
            treated_samples = treated_samples.copy()
            control_samples = control_samples.copy()
            for col in evidence:
                if data[col].nunique() < 50:
                    bins = np.unique(data[col])
                    control_samples[col] = control_samples[col].apply(discretize_value, bins=bins)
                    treated_samples[col] = treated_samples[col].apply(discretize_value, bins=bins)

        for col, ev_value in evidence.items():
            control_samples = control_samples[control_samples[col] == ev_value]
            treated_samples = treated_samples[treated_samples[col] == ev_value]

        return treated_samples, control_samples

    def get_probability_distribution(
        self, samples: NDArray[Any], bin_sequence: Optional[NDArray[Any]] = None
    ) -> Tuple[Any, Any, Any]:
        """
        Calculate the discrete probability distribution of the given samples.

        Args:
            samples (NDArray[Any]): The array of samples.
            bin_sequence (Optional[NDArray[Any]]): Bin edges used when the samples take more than
                two distinct values. If None, min(#distinct values, 50) equal-width bins are used.

        Returns:
            Tuple: ``(support, probs, bin_edges)``:

            * one distinct value: ``([-1, 1], [1, 0], [-1, 1])`` if that value is -1, otherwise
              ``([-1, 1], [0, 1], [-1, 1])`` (i.e. a +/-1 encoding of binary variables is assumed);
            * two distinct values: the sorted distinct values (ndarray), their relative
              frequencies (ndarray) and ``[-1, 1]``;
            * otherwise: the bin centers, the probability mass of each bin, and the bin edges
              (all ndarrays). If no sample falls inside ``bin_sequence`` the masses are NaN.

        Raises:
            ValueError: If ``samples`` is empty and ``bin_sequence`` is None.
        """
        unique_values, unique_counts = np.unique(samples, return_counts=True)
        num_unique_values = len(unique_values)

        if num_unique_values == 1:
            # A constant sample is mapped onto the two-point support {-1, 1}; any value
            # other than -1 is treated as the "1" outcome.
            if unique_values[0] == -1:
                return [-1, 1], [1, 0], [-1, 1]
            return [-1, 1], [0, 1], [-1, 1]

        if num_unique_values == 2:
            # Binary samples: relative frequency of each observed value.
            probs = unique_counts / np.sum(unique_counts)
            return unique_values, probs, [-1, 1]

        if bin_sequence is not None:
            bins = bin_sequence
        else:
            nbins = min(num_unique_values, 50)
            if nbins == 0:
                raise ValueError("The number of unique values in the samples array is 0.")
            bins = nbins

        # Use raw counts and normalize by their total so each entry is the probability mass of
        # its bin. Normalizing a density (density=True) instead weights bins by 1/width and gives
        # wrong masses whenever the bin edges are not equally spaced.
        counts, bin_edges = np.histogram(samples, bins=bins)
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
        probs = counts / np.sum(counts)

        return bin_centers, probs, bin_edges

    @staticmethod
    def parse_vars(df: pd.DataFrame) -> Tuple[list[str], list[str], list[str]]:
        """Parse variables into continuous, discrete and binary variables.

        A column with more than 50 distinct (non-null) values is continuous, one with exactly
        2 is binary, and every other column is discrete.

        Args:
            df (pd.DataFrame): Dataframe to parse.

        Returns:
            tuple[list, list, list]: The continuous, discrete and binary column names, in column order.
        """
        cont_var = []
        dis_var = []
        bin_var = []
        for col in df.columns:
            n_unique = df[col].nunique()
            if n_unique > 50:
                cont_var.append(col)
            elif n_unique == 2:
                bin_var.append(col)
            else:
                dis_var.append(col)
        return cont_var, dis_var, bin_var